#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
评测LLM模型的困惑度(PPL)
用法：
python submission_bench_ppl.py \
  --model_dir /workspace/models/Meta-Llama-3.1-8B \
  --dataset wikitext2 \
  --nsamples 128 \
  --seqlen 2048 \
  --batch_size 1 \
  --device cuda
"""

import argparse
import os
import sys

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Supported datasets, keyed by their --dataset CLI name. Each value is the
# Hugging Face dataset path plus its config (for c4: the data_files mapping).
# NOTE(review): in this script only the keys are used (as argparse choices);
# load_dataset_split spells out the same paths on its own.
DATASETS = dict(
    wikitext2=("wikitext", "wikitext-2-raw-v1"),
    c4=("allenai/c4", {"validation": "en/c4-validation.00000-of-00008.json.gz"}),
    ptb=("ptb_text_only", "penn_treebank"),
)

def load_dataset_split(dataset_name, split, c4_max_docs=10000):
    """Load one split of a supported dataset and return it as a single string.

    Args:
        dataset_name: "wikitext2", "ptb" or "c4".
        split: Hugging Face split name to load. For "c4" only "validation" is
            available, because only the first validation shard is fetched.
        c4_max_docs: number of leading c4 validation documents to keep.

    Returns:
        The split's texts joined into one string ("\\n\\n" separator for
        wikitext2 and c4, a single space for ptb).

    Raises:
        ValueError: if ``dataset_name`` is not supported, or a c4 split other
            than "validation" is requested.
    """
    # Validate before importing ``datasets`` so bad arguments produce a clear
    # ValueError regardless of whether that package is installed.
    if dataset_name not in ("wikitext2", "ptb", "c4"):
        raise ValueError(f"Dataset {dataset_name} not supported")
    if dataset_name == "c4" and split != "validation":
        # Only the validation shard is downloaded below; answering another
        # split with validation data would silently mislabel the result.
        raise ValueError(f"Dataset c4 only supports the 'validation' split, got {split!r}")

    # Imported lazily so the module can be imported without ``datasets``.
    from datasets import load_dataset

    if dataset_name == "wikitext2":
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
        # Join WikiText into one continuous text.
        return "\n\n".join(dataset["text"])

    if dataset_name == "ptb":
        dataset = load_dataset("ptb_text_only", "penn_treebank", split=split)
        return " ".join(dataset["sentence"])

    # c4: the validation shard is large, so only its first c4_max_docs
    # documents are used for evaluation.
    dataset = load_dataset(
        "allenai/c4",
        data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        split=f"validation[:{c4_max_docs}]",
    )
    return "\n\n".join(dataset["text"])

def get_test_samples(text, tokenizer, n_samples, seq_len, seed):
    """Draw ``n_samples`` random fixed-length token windows from ``text``.

    The text is tokenized once and window start positions are drawn with a
    NumPy RNG seeded by ``seed``, so identical arguments give identical
    windows. Windows may overlap.

    Args:
        text: the evaluation corpus as one string.
        tokenizer: callable used as ``tokenizer(text, return_tensors="pt")``
            that returns a mapping with an ``"input_ids"`` tensor whose first
            row is the token sequence (e.g. a Hugging Face tokenizer).
        n_samples: number of windows to draw (>= 1).
        seq_len: tokens per window (>= 1).
        seed: seed for the sampling RNG.

    Returns:
        A ``(n_samples, seq_len)`` tensor of token ids.

    Raises:
        ValueError: if ``n_samples`` or ``seq_len`` is < 1, or ``text``
            tokenizes to no tokens.
    """
    if n_samples < 1 or seq_len < 1:
        raise ValueError(
            f"n_samples and seq_len must be >= 1, got {n_samples} and {seq_len}"
        )

    input_ids = tokenizer(text, return_tensors="pt")["input_ids"][0]
    n_tokens = input_ids.size(0)
    if n_tokens == 0:
        raise ValueError("text produced no tokens")

    # If the corpus has fewer tokens than requested in total, tile it so there
    # is room for every window.
    if n_tokens < seq_len * n_samples:
        repeats = seq_len * n_samples // n_tokens + 1
        input_ids = input_ids.repeat(repeats)
        n_tokens = input_ids.size(0)

    # Exclusive upper bound for start indices. It stays at
    # ``n_tokens - seq_len - 1`` (so the last two valid starts are never drawn)
    # to keep a given seed selecting the same windows as before; it is clamped
    # to 1 because an exact-fit corpus would otherwise yield an empty range and
    # make randint raise. n_tokens >= seq_len here, so start 0 is always valid.
    high = max(n_tokens - seq_len - 1, 1)

    # A private RandomState produces the same values as np.random.seed(seed)
    # followed by np.random.randint, without clobbering the global NumPy RNG.
    rng = np.random.RandomState(seed)
    starts = rng.randint(0, high, size=n_samples)

    return torch.stack([input_ids[start:start + seq_len] for start in starts])

def _evaluate_ppl(model, samples, batch_size):
    """Return the perplexity of ``model`` over the token windows in ``samples``.

    Args:
        model: causal LM called as ``model(input_ids, labels=input_ids)`` that
            returns an object with a scalar ``.loss``; batches are moved to
            ``model.device``.
        samples: 2-D tensor of token ids with one equal-length window per row
            (as returned by ``get_test_samples``), so no padding or attention
            mask is needed.
        batch_size: number of windows per forward pass (>= 1).

    Returns:
        ``exp(mean token NLL)`` as a Python float.

    Raises:
        ValueError: if no token positions were scored (empty ``samples`` or
            windows shorter than 2 tokens).
    """
    device = model.device
    total_nll = 0.0
    total_tokens = 0

    with torch.no_grad():
        for start in tqdm(range(0, samples.shape[0], batch_size), desc="bench_ppl"):
            batch = samples[start:start + batch_size].to(device)
            # With labels == inputs the labels are shifted inside the model, so
            # .loss averages over (seq_len - 1) predicted positions per row;
            # scale it back to a sum so a smaller final batch is weighted correctly.
            loss = model(batch, labels=batch).loss
            n_scored = batch.size(0) * (batch.size(1) - 1)
            total_nll += loss.item() * n_scored
            total_tokens += n_scored

    if total_tokens == 0:
        raise ValueError("No token positions were evaluated; need >= 1 sample of length >= 2")

    # float64 avoids float32 rounding in the printed digits and overflows to inf
    # instead of raising for a degenerate model.
    return torch.exp(torch.tensor(total_nll / total_tokens, dtype=torch.float64)).item()

def main():
    """Benchmark entry point: parse CLI args, load model and data, print ``ppl=<value>``.

    If the tokenizer, model or dataset fails to load, the error is printed to
    stderr and the process exits with status 1. Invalid numeric arguments are
    rejected by argparse (exit status 2).
    """
    parser = argparse.ArgumentParser(description="Benchmark PPL on datasets")
    parser.add_argument("--model_dir", type=str, required=True, help="Path to model directory")
    parser.add_argument("--dataset", type=str, default="wikitext2", choices=list(DATASETS),
                        help="Dataset to evaluate on")
    parser.add_argument("--nsamples", type=int, default=128, help="Number of samples to evaluate")
    parser.add_argument("--seqlen", type=int, default=2048, help="Sequence length")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for evaluation")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run on")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    # Reject values that would crash mid-run (batch_size 0 makes range() raise)
    # or leave the PPL undefined (a 1-token window has no predicted positions).
    if args.nsamples < 1:
        parser.error("--nsamples must be >= 1")
    if args.seqlen < 2:
        parser.error("--seqlen must be >= 2")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    # Load the tokenizer; trust_remote_code matches the model load below so
    # custom-code checkpoints load consistently.
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as e:
        print(f"Error loading tokenizer: {e}", file=sys.stderr)
        sys.exit(1)

    try:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_dir,
            device_map=args.device,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
        model.eval()
    except Exception as e:
        print(f"Error loading model: {e}", file=sys.stderr)
        sys.exit(1)

    # Load the dataset and draw the evaluation windows.
    # NOTE(review): the "validation" split is used for every dataset; many
    # published wikitext2/ptb numbers use "test" — confirm this is intended.
    try:
        text = load_dataset_split(args.dataset, "validation")
        samples = get_test_samples(text, tokenizer, args.nsamples, args.seqlen, args.seed)
    except Exception as e:
        print(f"Error loading dataset: {e}", file=sys.stderr)
        sys.exit(1)

    ppl = _evaluate_ppl(model, samples, args.batch_size)
    print(f"ppl={ppl:.4f}")

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()